#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import requests

# Task name used by start_shard_task_success.
SHARD_TASK_NAME = "test-shard"
# Task name with non-ASCII and punctuation characters, used by
# get_illegal_char_task_status_failed to build the status URL.
ILLEGAL_CHAR_TASK_NAME = "t-Ë!s`t"
# Source names referenced by the task and template definitions in this script.
SOURCE1_NAME = "mysql-01"
SOURCE2_NAME = "mysql-02"


# Base URL of the task API; every request URL in this script is derived from it.
API_ENDPOINT = "http://127.0.0.1:8361/api/v1/tasks"


def start_task_failed():
    """Post a task with an invalid shard mode and assert the API answers 400."""
    source_names = (SOURCE1_NAME, SOURCE2_NAME)
    migrate_rules = [
        {
            "source": {"source_name": source, "schema": "openapi", "table": "*"},
            "target": {"schema": "openapi", "table": "t"},
        }
        for source in source_names
    ]
    task = {
        "name": "test",
        "task_mode": "all",
        # "pessimistic_xxd" is deliberately not a valid shard mode.
        "shard_mode": "pessimistic_xxd",
        "meta_schema": "dm-meta",
        "enhance_online_schema_change": True,
        "on_duplicate": "error",
        "target_config": {
            "host": "127.0.0.1",
            "port": 4000,
            "user": "root",
            "password": "",
        },
        "table_migrate_rule": migrate_rules,
        "source_config": {
            "source_conf": [{"source_name": source} for source in source_names],
        },
    }
    request_body = {"remove_meta": True, "task": task}
    resp = requests.post(url=API_ENDPOINT, json=request_body)
    print("start_task_failed resp=", resp.json())
    assert resp.status_code == 400


def start_noshard_task_success(task_name, tartget_table_name=""):
    """Post a two-source task named ``task_name`` and assert the API answers 201.

    ``tartget_table_name`` (spelling kept for caller compatibility) is used as
    the target table of both migrate rules; it defaults to an empty string.
    """
    source_names = (SOURCE1_NAME, SOURCE2_NAME)
    migrate_rules = [
        {
            "source": {"source_name": source, "schema": "openapi", "table": "*"},
            "target": {"schema": "openapi", "table": tartget_table_name},
        }
        for source in source_names
    ]
    task = {
        "name": task_name,
        "task_mode": "all",
        "shard_mode": "pessimistic",
        "meta_schema": "dm-meta",
        "enhance_online_schema_change": True,
        "on_duplicate": "error",
        "target_config": {
            "host": "127.0.0.1",
            "port": 4000,
            "user": "root",
            "password": "",
        },
        "table_migrate_rule": migrate_rules,
        "source_config": {
            "source_conf": [{"source_name": source} for source in source_names],
        },
    }
    request_body = {"remove_meta": True, "task": task}
    resp = requests.post(url=API_ENDPOINT, json=request_body)
    print("start_noshard_task_success resp=", resp.json())
    assert resp.status_code == 201


def start_shard_task_success():
    """Post the shard task (``SHARD_TASK_NAME``) and assert the API answers 201.

    Both sources are routed to ``openapi.t``; each migrate rule references its
    own binlog filter rule, defined under the task's ``binlog_filter_rule``.
    """
    # (source name, binlog filter rule attached to that source's migrate rule)
    source_filters = ((SOURCE1_NAME, "rule-1"), (SOURCE2_NAME, "rule-2"))
    migrate_rules = [
        {
            "source": {"source_name": source, "schema": "openapi", "table": "*"},
            "target": {"schema": "openapi", "table": "t"},
            "binlog_filter_rule": [filter_name],
        }
        for source, filter_name in source_filters
    ]
    filter_rules = {
        "rule-1": {
            "ignore_event": ["delete"],
        },
        "rule-2": {
            "ignore_sql": ["alter table .* add column `aaa` int"],
        },
    }
    task = {
        "name": SHARD_TASK_NAME,
        "task_mode": "all",
        "shard_mode": "pessimistic",
        "meta_schema": "dm-meta",
        "enhance_online_schema_change": True,
        "on_duplicate": "error",
        "target_config": {
            "host": "127.0.0.1",
            "port": 4000,
            "user": "root",
            "password": "",
        },
        "table_migrate_rule": migrate_rules,
        "source_config": {
            "source_conf": [{"source_name": source} for source, _ in source_filters],
        },
        "binlog_filter_rule": filter_rules,
    }
    request_body = {"remove_meta": True, "task": task}
    resp = requests.post(url=API_ENDPOINT, json=request_body)
    print("start_shard_task_success resp=", resp.json())
    assert resp.status_code == 201


def get_task_status_failed(task_name):
    """Request the status of ``task_name`` and assert the API answers 400."""
    status_url = "/".join([API_ENDPOINT, task_name, "status"])
    resp = requests.get(url=status_url)
    print("get_task_status_failed resp=", resp.json())
    assert resp.status_code == 400


def get_illegal_char_task_status_failed():
    """Request the status of ``ILLEGAL_CHAR_TASK_NAME`` and expect a 400.

    Also asserts that the task name appears in the response's ``error_msg``.
    """
    # The task name contains illegal characters, but the API server can handle
    # it; the 400 is returned because the task has not been started.
    status_url = "/".join([API_ENDPOINT, ILLEGAL_CHAR_TASK_NAME, "status"])
    resp = requests.get(url=status_url)
    print("get_illegal_char_task_status_failed resp=", resp.json())
    assert resp.status_code == 400

    expected_name = ILLEGAL_CHAR_TASK_NAME
    if sys.version_info.major == 2:
        # The name needs decoding on Python 2 before it can be compared.
        expected_name = expected_name.decode("utf-8")
    assert expected_name in resp.json()["error_msg"]


def get_task_status_success(task_name, total):
    """Fetch the status of ``task_name``; expect 200 and ``total`` entries.

    ``total`` may be a string (it is converted with ``int``).
    """
    status_url = "/".join([API_ENDPOINT, task_name, "status"])
    resp = requests.get(url=status_url)
    payload = resp.json()
    assert resp.status_code == 200
    print("get_task_status_success resp=", payload)
    assert payload["total"] == int(total)


def get_task_status_success_but_worker_meet_error(task_name, total):
    """Fetch the status of ``task_name`` and check every entry carries an error.

    Expects HTTP 200, a ``total`` equal to ``int(total)``, and for each item in
    ``data`` a matching ``name`` and a non-``None`` ``error_msg``.
    """
    status_url = "/".join([API_ENDPOINT, task_name, "status"])
    resp = requests.get(url=status_url)
    payload = resp.json()
    assert resp.status_code == 200
    print("get_task_status_success_but_worker_meet_error resp=", payload)
    assert payload["total"] == int(total)
    for entry in payload["data"]:
        assert entry["name"] == task_name
        assert entry["error_msg"] is not None


def get_task_list(task_count):
    """List tasks; expect HTTP 200 and a ``total`` equal to ``int(task_count)``."""
    resp = requests.get(url=API_ENDPOINT)
    payload = resp.json()
    assert resp.status_code == 200
    print("get_task_list resp=", payload)
    assert payload["total"] == int(task_count)


def get_task_list_with_status(task_count, task_name, status_count):
    """List tasks with their status and check the entry for ``task_name``.

    Expects HTTP 200, a ``total`` equal to ``int(task_count)``, at least one
    listed task named ``task_name``, and ``int(status_count)`` items in the
    ``status_list`` of each such task.
    """
    resp = requests.get(url=API_ENDPOINT + "?with_status=true")
    payload = resp.json()
    assert resp.status_code == 200
    print("get_task_list_with_status resp=", payload)

    assert payload["total"] == int(task_count)
    matched = 0
    for entry in payload["data"]:
        if entry["name"] != task_name:
            continue
        matched += 1
        assert len(entry["status_list"]) == int(status_count)
    # The named task must be present in the listing.
    assert matched > 0


def pause_task_success(task_name, source_name):
    """Pause ``task_name`` on ``source_name`` and assert the API answers 200.

    The response body is printed only when the status code is not 200.
    """
    pause_url = "/".join([API_ENDPOINT, task_name, "pause"])
    resp = requests.post(url=pause_url, json=[source_name])
    succeeded = resp.status_code == 200
    if not succeeded:
        print("pause_task_failed resp=", resp.json())
    assert succeeded


def resume_task_success(task_name, source_name):
    """Resume ``task_name`` on ``source_name`` and assert the API answers 200.

    The response body is printed only when the status code is not 200.
    """
    resume_url = "/".join([API_ENDPOINT, task_name, "resume"])
    resp = requests.post(url=resume_url, json=[source_name])
    succeeded = resp.status_code == 200
    if not succeeded:
        print("resume_task_failed resp=", resp.json())
    assert succeeded


def operate_schema_and_table_success(task_name, source_name, schema_name, table_name):
    """Walk through the schema/table endpoints of one source of a task.

    In order, with an assertion after each step: list the schemas, list the
    tables of ``schema_name``, fetch the definition of ``table_name``, delete
    that table, confirm the table list is empty, then PUT a CREATE TABLE
    statement back and confirm exactly one table is listed.
    """
    schemas_url = "/".join(
        [API_ENDPOINT, task_name, "sources", source_name, "schemas"]
    )

    # 1. schema list must be non-empty and contain the requested schema
    schemas_resp = requests.get(url=schemas_url)
    assert schemas_resp.status_code == 200
    print("get_task_schema_success schema resp=", schemas_resp.json())
    schemas = schemas_resp.json()
    assert len(schemas) > 0
    assert schema_name in schemas

    # 2. table list of that schema must contain the requested table
    tables_url = "/".join([schemas_url, schema_name])
    tables_resp = requests.get(url=tables_url)
    assert tables_resp.status_code == 200
    print("get_task_schema_success table resp=", tables_resp.json())
    assert table_name in tables_resp.json()

    # 3. the single-table endpoint returns the table's definition
    one_table_url = "/".join([tables_url, table_name])
    definition_resp = requests.get(url=one_table_url)
    assert definition_resp.status_code == 200
    definition = definition_resp.json()
    print("get_task_schema_success create table resp=", definition)
    assert definition["table_name"] == table_name
    assert definition["schema_name"] == schema_name
    assert table_name in definition["schema_create_sql"]

    # 4. delete the table
    delete_resp = requests.delete(url=one_table_url)
    assert delete_resp.status_code == 204

    # 5. after the delete the schema lists no tables
    emptied_resp = requests.get(url=tables_url)
    assert emptied_resp.status_code == 200
    print("get_task_schema_success table resp=", emptied_resp.json())
    assert len(emptied_resp.json()) == 0

    # 6. add the table back again
    # NOTE(review): the statement hard-codes openapi.t1 regardless of
    # schema_name/table_name -- confirm callers always pass those values.
    restore_body = {
        "sql_content": "CREATE TABLE openapi.t1(i TINYINT, j INT UNIQUE KEY);",
        "flush": True,
        "sync": True,
    }
    restore_resp = requests.put(url=one_table_url, json=restore_body)
    assert restore_resp.status_code == 200
    restored_tables = requests.get(url=tables_url).json()
    print("get_task_schema_success table resp=", restored_tables)
    assert len(restored_tables) == 1


def stop_task_failed(task_name):
    """Send DELETE for ``task_name`` and assert the API answers 400."""
    task_url = "/".join([API_ENDPOINT, task_name])
    resp = requests.delete(url=task_url)
    print("stop_task_failed resp=", resp.json())
    assert resp.status_code == 400


def stop_task_success(task_name):
    """Send DELETE for ``task_name`` and assert the API answers 204."""
    task_url = "/".join([API_ENDPOINT, task_name])
    resp = requests.delete(url=task_url)
    assert resp.status_code == 204
    print("stop_task_success")


def create_task_template_success(task_name, target_table_name=""):
    """Post a task template named ``task_name`` and assert the API answers 201.

    ``target_table_name`` is used as the target table of both migrate rules;
    it defaults to an empty string.
    """
    templates_url = "/".join([API_ENDPOINT, "templates"])
    source_names = (SOURCE1_NAME, SOURCE2_NAME)
    migrate_rules = [
        {
            "source": {"source_name": source, "schema": "openapi", "table": "*"},
            "target": {"schema": "openapi", "table": target_table_name},
        }
        for source in source_names
    ]
    template = {
        "name": task_name,
        "task_mode": "all",
        "shard_mode": "pessimistic",
        "meta_schema": "dm-meta",
        "enhance_online_schema_change": True,
        "on_duplicate": "error",
        "target_config": {
            "host": "127.0.0.1",
            "port": 4000,
            "user": "root",
            "password": "",
        },
        "table_migrate_rule": migrate_rules,
        "source_config": {
            "source_conf": [{"source_name": source} for source in source_names],
        },
    }
    # Unlike task creation, the template body is the task definition itself.
    resp = requests.post(url=templates_url, json=template)
    print("create_task_template_success resp=", resp.json())
    assert resp.status_code == 201


def create_task_template_failed():
    """Post a task template with an invalid shard mode and expect a 400."""
    templates_url = "/".join([API_ENDPOINT, "templates"])
    source_names = (SOURCE1_NAME, SOURCE2_NAME)
    migrate_rules = [
        {
            "source": {"source_name": source, "schema": "openapi", "table": "*"},
            "target": {"schema": "openapi", "table": "t"},
        }
        for source in source_names
    ]
    template = {
        "name": "test",
        "task_mode": "all",
        # "pessimistic_xxd" is deliberately not a valid shard mode.
        "shard_mode": "pessimistic_xxd",
        "meta_schema": "dm-meta",
        "enhance_online_schema_change": True,
        "on_duplicate": "error",
        "target_config": {
            "host": "127.0.0.1",
            "port": 4000,
            "user": "root",
            "password": "",
        },
        "table_migrate_rule": migrate_rules,
        "source_config": {
            "source_conf": [{"source_name": source} for source in source_names],
        },
    }
    resp = requests.post(url=templates_url, json=template)
    print("create_task_template_failed resp=", resp.json())
    assert resp.status_code == 400


def list_task_template(count):
    """List task templates; expect HTTP 200 and a ``total`` of ``int(count)``."""
    templates_url = "/".join([API_ENDPOINT, "templates"])
    resp = requests.get(url=templates_url)
    payload = resp.json()
    assert resp.status_code == 200
    print("list_task_template resp=", payload)
    assert payload["total"] == int(count)


def import_task_template(success_count, failed_count):
    """Trigger the template import endpoint and check the reported lists.

    Expects HTTP 202, ``int(success_count)`` items in ``success_task_list``
    and ``int(failed_count)`` items in ``failed_task_list``.
    """
    import_url = "/".join([API_ENDPOINT, "templates", "import"])
    resp = requests.post(url=import_url)
    payload = resp.json()
    print("import_task_template resp=", payload)
    assert resp.status_code == 202
    succeeded = payload["success_task_list"]
    assert len(succeeded) == int(success_count)
    failed = payload["failed_task_list"]
    assert len(failed) == int(failed_count)


def get_task_template(name):
    """Fetch the template ``name``; expect HTTP 200 and a matching ``name``."""
    template_url = "/".join([API_ENDPOINT, "templates", name])
    resp = requests.get(url=template_url)
    payload = resp.json()
    assert resp.status_code == 200
    print("get_task_template resp=", payload)
    assert payload["name"] == name


def update_task_template_success(name, task_mode):
    """Change the ``task_mode`` of template ``name`` and verify it was stored.

    Reads the current template, PUTs it back with the new ``task_mode``
    (expecting HTTP 200), then re-reads it to confirm the stored value.
    """
    template_url = "/".join([API_ENDPOINT, "templates", name])

    # Start from the template as currently stored so only task_mode changes.
    template = requests.get(url=template_url).json()
    template["task_mode"] = task_mode
    resp = requests.put(url=template_url, json=template)
    print("update_task_template_success resp=", resp.json())
    assert resp.status_code == 200

    # Read the template back to confirm the update took effect.
    refreshed = requests.get(url=template_url).json()
    assert refreshed["task_mode"] == task_mode


def delete_task_template(name):
    """Send DELETE for template ``name`` and assert the API answers 204."""
    template_url = "/".join([API_ENDPOINT, "templates", name])
    resp = requests.delete(url=template_url)
    assert resp.status_code == 204
    print("delete_task_template")


def check_noshard_task_dump_status_success(task_name, total):
    """Check the dump progress reported by the first status entry of a task.

    Expects HTTP 200 and ``dump_status.finished_bytes`` of the first item in
    ``data`` to equal ``int(total)``.
    """
    status_url = "/".join([API_ENDPOINT, task_name, "status"])
    resp = requests.get(url=status_url)
    payload = resp.json()
    assert resp.status_code == 200
    print("check_dump_status_success resp=", payload)
    first_entry = payload["data"][0]
    finished_bytes = first_entry["dump_status"]["finished_bytes"]
    assert finished_bytes == int(total)


if __name__ == "__main__":
    # Command-line dispatcher: argv[1] selects one of the check functions
    # above by name, and argv[2:] are forwarded to it as positional string
    # arguments (the checks convert numeric arguments with int() themselves).
    FUNC_MAP = {
        "start_task_failed": start_task_failed,
        "start_noshard_task_success": start_noshard_task_success,
        "start_shard_task_success": start_shard_task_success,
        "stop_task_failed": stop_task_failed,
        "pause_task_success": pause_task_success,
        "resume_task_success": resume_task_success,
        "stop_task_success": stop_task_success,
        "get_task_list": get_task_list,
        "get_task_list_with_status": get_task_list_with_status,
        "get_task_status_failed": get_task_status_failed,
        "get_illegal_char_task_status_failed": get_illegal_char_task_status_failed,
        "get_task_status_success": get_task_status_success,
        "get_task_status_success_but_worker_meet_error": get_task_status_success_but_worker_meet_error,
        "operate_schema_and_table_success": operate_schema_and_table_success,
        "create_task_template_success": create_task_template_success,
        "create_task_template_failed": create_task_template_failed,
        "list_task_template": list_task_template,
        "import_task_template": import_task_template,
        "get_task_template": get_task_template,
        "update_task_template_success": update_task_template_success,
        "delete_task_template": delete_task_template,
        "check_noshard_task_dump_status_success": check_noshard_task_dump_status_success,
    }

    # Validate the function name BEFORE indexing argv/FUNC_MAP so a missing or
    # misspelled name produces a usage message instead of a bare
    # IndexError/KeyError traceback. sys.exit(str) writes the message to
    # stderr and exits with status 1, so failures still yield a non-zero exit.
    available = ", ".join(sorted(FUNC_MAP))
    if len(sys.argv) < 2:
        sys.exit(
            "usage: %s <function_name> [args...]\navailable functions: %s"
            % (sys.argv[0], available)
        )
    func_name = sys.argv[1]
    if func_name not in FUNC_MAP:
        sys.exit(
            "unknown function %r; available functions: %s" % (func_name, available)
        )

    func = FUNC_MAP[func_name]
    # Unpacking an empty argv[2:] is the same as calling func() with no
    # arguments, so no separate zero-argument branch is needed.
    func(*sys.argv[2:])
